package com.virjar.xposed_extention;

import android.util.Log;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;

import de.robv.android.xposed.XC_MethodHook;
import de.robv.android.xposed.XposedBridge;
import de.robv.android.xposed.XposedHelpers;

/**
 * Created by virjar on 2018/4/11.<br>
 * Convenience wrappers for injecting hooks into classes that may not be loadable yet.
 * <p>
 * The monitor records class loaders it observes (through hooks on the {@link ClassLoader}
 * constructor and on {@code Class.forName(String, boolean, ClassLoader)}) and, whenever a new
 * loader is recorded or a new listener is registered, tries to load every watched class name
 * with the known loaders. Listeners of a class name are invoked once the class could be loaded
 * and are then removed.
 */
public class ClassLoadMonitor {
    /**
     * Listener notified when a watched class could be loaded.
     */
    public interface OnClassLoader {
        /**
         * @param clazz the class that was resolved for the watched class name
         */
        void onClassLoad(Class<?> clazz);
    }

    /** Loaders that are only consulted by {@link #fireCallBack()} and by first-time lookups in {@link #tryLoadClass(String)}. */
    private static final Set<ClassLoader> lowPriorityClassLoader = Sets.newConcurrentHashSet();
    /** Watched class name -> listeners still waiting for that class. */
    private static final ConcurrentMap<String, Set<OnClassLoader>> callBacks = Maps.newConcurrentMap();
    /** Loaders collected so far that are not registered as low priority. */
    private static final Set<ClassLoader> hookedClassLoader = Sets.newConcurrentHashSet();
    /** Per-thread flag that is set while {@link #fireCallBack()} is dispatching on the current thread. */
    private static final ThreadLocal<Boolean> isCallBackRunning = new ThreadLocal<>();
    /** Class name -> class previously resolved by {@link #tryLoadClass(String)}. */
    private static final Map<String, Class<?>> classCache = Maps.newConcurrentMap();

    static {
        enableClassMonitor();
    }

    /**
     * Intentionally empty. Calling it makes sure this class is initialized, which runs the static
     * initializer above and installs the monitoring hooks.
     */
    public static void setUp() {
        //do nothing
    }

    /**
     * Records a class loader and re-evaluates the pending listeners if the loader is new.
     * Null loaders, low-priority loaders and already recorded loaders are ignored.
     * Parent class loaders are not walked here.
     *
     * @param classLoader the loader to record, may be null
     */
    private static void collectClassLoader(ClassLoader classLoader) {
        if (classLoader == null) {
            return;
        }
        if (lowPriorityClassLoader.contains(classLoader)) {
            return;
        }
        // add() reports whether the loader was new, so no separate contains() check is needed
        if (!hookedClassLoader.add(classLoader)) {
            return;
        }
        fireCallBack();
    }

    /**
     * Registers a class loader as low priority. If the loader had been collected as a regular
     * loader it is moved to the low-priority set. Pending listeners are re-evaluated afterwards.
     *
     * @param classLoader the loader to register; null is ignored
     */
    public static void addLowPriorityClassLoader(ClassLoader classLoader) {
        if (classLoader == null) {
            // the backing concurrent sets do not accept null elements
            return;
        }
        if (lowPriorityClassLoader.contains(classLoader)) {
            return;
        }
        hookedClassLoader.remove(classLoader);
        lowPriorityClassLoader.add(classLoader);
        fireCallBack();
    }

    /**
     * Installs the hooks used to discover class loaders and seeds the loader sets with the
     * loaders that are already known when this class is initialized.
     */
    private static void enableClassMonitor() {
        // hook the ClassLoader constructor so that every subclass instance is observed
        XposedHelpers.findAndHookConstructor(ClassLoader.class, new SingletonXC_MethodHook() {

            @Override
            protected void afterHookedMethod(MethodHookParam param) throws Throwable {
                collectClassLoader((ClassLoader) param.thisObject);
            }

        });

        // entry point for implicit class loading
        XposedHelpers.findAndHookMethod(Class.class, "forName", String.class, boolean.class, ClassLoader.class, new SingletonXC_MethodHook() {
            @Override
            protected void afterHookedMethod(MethodHookParam param) throws Throwable {
                collectClassLoader((ClassLoader) param.args[2]);
            }
        });

        if (SharedObject.loadPackageParam != null) {
            collectClassLoader(SharedObject.loadPackageParam.classLoader);
        }
        addLowPriorityClassLoader(ClassLoadMonitor.class.getClassLoader());
        collectClassLoader(Thread.currentThread().getContextClassLoader());
        //do not need to fire call back for masterClassLoader

    }


    /**
     * Tries to load every watched class name with the known class loaders and invokes the
     * listeners of each name that could be loaded. Only the listeners that were actually invoked
     * are removed, so a listener registered while dispatching is in progress stays registered
     * until a later run. Does nothing when called re-entrantly on the same thread.
     */
    private static void fireCallBack() {
        if (Boolean.TRUE.equals(isCallBackRunning.get())) {
            return;
        }
        if (callBacks.isEmpty()) {
            return;
        }
        Set<ClassLoader> testClassLoader = Sets.newHashSet();
        testClassLoader.addAll(hookedClassLoader);
        testClassLoader.addAll(lowPriorityClassLoader);
        if (testClassLoader.isEmpty()) {
            return;
        }

        isCallBackRunning.set(true);
        // listeners must not be re-entered, otherwise a callback could be triggered for a class
        // that is not completely loaded yet
        try {
            for (Map.Entry<String, Set<OnClassLoader>> entry : callBacks.entrySet()) {
                String monitorClassName = entry.getKey();
                Set<OnClassLoader> registered = entry.getValue();
                if (registered.isEmpty()) {
                    // nothing is waiting for this class any more, so do not try to load it
                    continue;
                }
                for (ClassLoader classLoader : testClassLoader) {
                    Class<?> aClass;
                    try {
                        aClass = classLoader.loadClass(monitorClassName);
                    } catch (Throwable ignored) {
                        // this loader can not provide the class (yet); try the next one
                        continue;
                    }
                    // snapshot so that only listeners that really ran get removed below
                    Set<OnClassLoader> invoked = Sets.newHashSet(registered);
                    for (OnClassLoader onClassLoader : invoked) {
                        try {
                            onClassLoader.onClassLoad(aClass);
                        } catch (Throwable throwable) {
                            Log.e("weijia", "error when callback for class load monitor", throwable);
                        }
                    }
                    registered.removeAll(invoked);
                    break;
                }
            }
        } finally {
            isCallBackRunning.remove();
        }
    }

    /**
     * Adds a load listener for a class. Note that this method does not de-duplicate repeated
     * registrations; callers have to implement their own de-duplication of callbacks.<br>
     * The listener is called back as early as possible and is typically used to register hooks
     * (which avoids missing the logic of interest because a hook was registered too late).
     *
     * @param className     the class name to watch; if several classes with the same name exist
     *                      in different class loaders, monitoring may fail
     * @param onClassLoader the listener to call back
     */
    public static void addClassLoadMonitor(String className, OnClassLoader onClassLoader) {
        collectClassLoader(Thread.currentThread().getContextClassLoader());
        Set<OnClassLoader> onClassLoaders = callBacks.get(className);
        if (onClassLoaders == null) {
            Set<OnClassLoader> fresh = Sets.newConcurrentHashSet();
            // putIfAbsent returns null when our set was installed, otherwise the set that won the race
            Set<OnClassLoader> existing = callBacks.putIfAbsent(className, fresh);
            onClassLoaders = existing != null ? existing : fresh;
        }
        onClassLoaders.add(onClassLoader);
        fireCallBack();
    }

    /**
     * Tries to load a class without the caller having to know which class loader provides it.
     * A cached result is returned directly while its class loader is still one of the collected
     * loaders; otherwise the class is resolved again and the cache is refreshed with the new
     * result.
     *
     * @param className className
     * @return the class object, or null if the class can not be loaded
     */
    public static Class<?> tryLoadClass(String className) {
        collectClassLoader(Thread.currentThread().getContextClassLoader());
        Class<?> cached = classCache.get(className);
        if (cached == null) {
            Class<?> loaded = tryLoadClassInternal(className, true);
            if (loaded != null) {
                classCache.put(className, loaded);
            }
            return loaded;
        }

        // getClassLoader() may return null and the concurrent set rejects null lookups
        ClassLoader cachedLoader = cached.getClassLoader();
        if (cachedLoader != null && hookedClassLoader.contains(cachedLoader)) {
            return cached;
        }
        Class<?> candidateRet = tryLoadClassInternal(className, false);
        if (candidateRet != null) {
            // remember the freshly resolved class, not the stale cached one
            classCache.put(className, candidateRet);
            return candidateRet;
        }
        return cached;
    }

    /**
     * Looks a class up in the collected class loaders and finally with a null class loader.
     *
     * @param className       the class name to resolve
     * @param scanLowPriority whether the low-priority loaders are searched as well
     * @return the first class found, or whatever the final lookup with a null loader returns
     */
    private static Class<?> tryLoadClassInternal(String className, boolean scanLowPriority) {
        Set<ClassLoader> testClassLoader = Sets.newHashSet();
        testClassLoader.addAll(hookedClassLoader);

        if (scanLowPriority) {
            testClassLoader.addAll(lowPriorityClassLoader);
        }

        for (ClassLoader classLoader : testClassLoader) {
            try {
                Class<?> aClass = ReflectUtil.findClassIfExists(className, classLoader);
                if (aClass != null) {
                    return aClass;
                }
            } catch (Throwable ignored) {
                // VM-level class loading errors are possible here, hence Throwable is caught
                // and the next loader is tried
            }
        }
        return ReflectUtil.findClassIfExists(className, null);
    }

    /**
     * Hooks a method of the given class as soon as the class can be loaded.
     *
     * @param className                 the class declaring the method
     * @param methodName                the method name
     * @param parameterTypesAndCallback passed through unchanged to
     *                                  {@code XposedHelpers.findAndHookMethod}
     */
    public static void findAndHookMethod(String className, final String methodName, final Object... parameterTypesAndCallback) {
        addClassLoadMonitor(className, new OnClassLoader() {
            @Override
            public void onClassLoad(Class<?> clazz) {
                XposedHelpers.findAndHookMethod(clazz, methodName, parameterTypesAndCallback);
            }
        });
    }

    /**
     * Like {@link #findAndHookMethod(String, String, Object...)}, but when hooking fails on the
     * class itself the superclasses are tried one after another (excluding {@code Object}).
     * If no class in the chain can be hooked, the listener throws an
     * {@link IllegalStateException} whose cause is the first failure.
     *
     * @param className                 the class to start from
     * @param methodName                the method name
     * @param parameterTypesAndCallback passed through unchanged to
     *                                  {@code XposedHelpers.findAndHookMethod}
     */
    public static void findAndHookMethodWithSupper(String className, final String methodName, final Object... parameterTypesAndCallback) {
        addClassLoadMonitor(className, new OnClassLoader() {
            @Override
            public void onClassLoad(Class<?> clazz) {
                Throwable firstFailure = null;
                Class<?> current = clazz;
                // getSuperclass() returns null for interfaces, so the walk must stop on null too
                while (current != null && current != Object.class) {
                    try {
                        XposedHelpers.findAndHookMethod(current, methodName, parameterTypesAndCallback);
                        return;
                    } catch (Throwable throwable) {
                        if (firstFailure == null) {
                            firstFailure = throwable;
                        }
                        current = current.getSuperclass();
                    }
                }
                throw new IllegalStateException(firstFailure);
            }
        });
    }

    /**
     * Hooks every non-abstract method declared by the class once it can be loaded.
     *
     * @param className the class whose declared methods are hooked
     * @param callback  the hook to install
     */
    public static void hookAllMethod(String className, XC_MethodHook callback) {
        hookAllMethod(className, null, callback);
    }

    /**
     * Hooks the non-abstract methods declared by the class once it can be loaded, optionally
     * restricted to one method name. Interfaces are skipped with an error log.
     *
     * @param className  the class whose declared methods are hooked
     * @param methodName only methods with this name are hooked; null hooks all of them
     * @param callback   the hook to install
     */
    public static void hookAllMethod(final String className, final String methodName, final XC_MethodHook callback) {
        addClassLoadMonitor(className, new OnClassLoader() {
            @Override
            public void onClassLoad(Class<?> clazz) {
                if (Modifier.isInterface(clazz.getModifiers())) {
                    Log.e("weijia", "the class : {" + clazz.getName() + "} is interface can not hook any method!!");
                    return;
                }
                for (Method method : clazz.getDeclaredMethods()) {
                    if (methodName != null && !method.getName().equals(methodName)) {
                        continue;
                    }
                    if (Modifier.isAbstract(method.getModifiers())) {
                        continue;
                    }
                    XposedBridge.hookMethod(method, callback);
                }
            }
        });
    }

    /**
     * Hooks all constructors of the class once it can be loaded.
     *
     * @param className the class whose constructors are hooked
     * @param callback  the hook to install
     */
    public static void hookAllConstructor(String className, final XC_MethodHook callback) {
        addClassLoadMonitor(className, new OnClassLoader() {
            @Override
            public void onClassLoad(Class<?> clazz) {
                XposedBridge.hookAllConstructors(clazz, callback);
            }
        });
    }
}